// Copyright 2022 Chen Jun
// Licensed under the MIT License.

// OpenCV
#include <opencv2/core.hpp>
#include <opencv2/core/mat.hpp>
#include <opencv2/core/types.hpp>
#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/opencv.hpp>

// STL
#include <algorithm>
#include <fstream>
#include <iomanip>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include "hnurm_detect/armor.hpp"
#include "hnurm_detect/number_classifier.hpp"

namespace hnurm
{
/**
 * @brief Load the ONNX classification network and its label list.
 *
 * @param model_path     path to the ONNX model file (cv::dnn throws if unreadable)
 * @param label_path     text file with one class name per line
 * @param threshold      minimum confidence kept by classify()
 * @param ignore_classes class names classify() should always drop
 * @throws std::runtime_error if the label file cannot be opened (failing fast
 *         here avoids an out-of-range class_names_ access later in classify())
 */
NumberClassifier::NumberClassifier(
    const std::string              &model_path,
    const std::string              &label_path,
    const double                    threshold,
    const std::vector<std::string> &ignore_classes
)
    : threshold(threshold)
{
    net_ = cv::dnn::readNetFromONNX(model_path);

    std::ifstream label_file(label_path);
    if(!label_file.is_open())
    {
        throw std::runtime_error("NumberClassifier: failed to open label file: " + label_path);
    }

    std::string line;
    while(std::getline(label_file, line))
    {
        // Strip a trailing '\r' so CRLF-encoded label files still produce names
        // that compare equal to ignore_classes_ and the literals in classify().
        if(!line.empty() && line.back() == '\r')
        {
            line.pop_back();
        }
        class_names_.push_back(line);
    }

    ignore_classes_ = ignore_classes;
}

/**
 * @brief Extract a fixed-size binary number patch from each armor.
 *
 * For every armor, perspective-warps the quadrilateral spanned by its two
 * light bars onto a canonical rectangle, crops the central number ROI,
 * and Otsu-binarizes it. The result is stored in armor.number_img.
 *
 * @param src    source frame the armors were detected in
 *               (NOTE(review): cvtColor below uses COLOR_RGB2GRAY, so this
 *               assumes src is RGB, not OpenCV's default BGR — confirm the
 *               upstream capture pipeline)
 * @param armors detected armors; number_img is filled in-place
 */
void NumberClassifier::extractNumbers(const cv::Mat &src, std::vector<Armor> &armors)
{
    // Light length in image
    constexpr int light_length = 12;
    // Image size after warp
    constexpr int warp_height       = 28;
    constexpr int small_armor_width = 32;
    constexpr int large_armor_width = 54;
    // Number ROI size
    const cv::Size roi_size(20, 28);
    for(auto &armor : armors)
    {
        // Warp perspective transform: map the four light-bar endpoints
        // (bottom-left, top-left, top-right, bottom-right) onto a canonical
        // rectangle so the number appears upright at a fixed scale.
        cv::Point2f lights_vertices[4]
            = {armor.left_light.bottom, armor.left_light.top, armor.right_light.top, armor.right_light.bottom};

        // Vertically center a light_length-tall band inside the warp target.
        const int   top_light_y        = (warp_height - light_length) / 2 - 1;
        const int   bottom_light_y     = top_light_y + light_length;
        const int   warp_width         = armor.armor_type == SMALL ? small_armor_width : large_armor_width;
        cv::Point2f target_vertices[4] = {
            cv::Point(0, bottom_light_y),
            cv::Point(0, top_light_y),
            cv::Point(warp_width - 1, top_light_y),
            cv::Point(warp_width - 1, bottom_light_y),
        };
        cv::Mat number_image,number_image_gray;
        auto    rotation_matrix = cv::getPerspectiveTransform(lights_vertices, target_vertices);
        cv::warpPerspective(src, number_image, rotation_matrix, cv::Size(warp_width, warp_height));

        // Get ROI: crop the horizontally-centered 20x28 patch holding the digit.
        number_image = number_image(cv::Rect(cv::Point((warp_width - roi_size.width) / 2, 0), roi_size));

        // Binarize with Otsu's automatic threshold.
        cv::cvtColor(number_image, number_image_gray, cv::COLOR_RGB2GRAY);
        cv::threshold(number_image_gray, number_image, 0, 255, cv::THRESH_BINARY | cv::THRESH_OTSU);
        int white = cv::countNonZero(number_image);
        if(white < 40)
        {
            // Too few foreground pixels: Otsu's threshold was likely dominated
            // by a small bright region (e.g. light-bar glare). Build a 0/1 mask
            // of the non-bright area, zero out the bright pixels in the
            // grayscale image, then re-run Otsu on what remains.
            cv::bitwise_not(number_image, number_image);
            number_image=number_image/255;
            cv::multiply(number_image_gray, number_image, number_image_gray);
            cv::threshold(number_image_gray, number_image, 0, 255, cv::THRESH_BINARY | cv::THRESH_OTSU);
        }
        armor.number_img = number_image;
    }
}

/**
 * @brief Run the classifier on every armor's number patch and filter the list.
 *
 * Fills armor.confidence / number / idx / classification_result for each
 * armor, then erases armors whose confidence is below threshold, whose class
 * is in ignore_classes_, or whose class is impossible for its armor size.
 *
 * @param armors     in-out: classified in place, then filtered
 * @param self_color currently unused; kept for API compatibility with callers
 */
void NumberClassifier::classify(std::vector<Armor> &armors,int self_color)
{
    (void)self_color;  // not used yet; see color-based filtering TODO below

    // TODO: parallelize, and batch DNN inference with blobFromImages()
    for(auto &armor : armors)
    {
        cv::Mat image = armor.number_img.clone();

        // Normalize pixel values to [0, 1]
        image = image / 255.0;

        // Create blob from image
        cv::Mat blob;
        cv::dnn::blobFromImage(image, blob, 1., cv::Size(20, 28));

        // Forward pass through the network
        net_.setInput(blob);
        cv::Mat outputs = net_.forward();

        // Softmax over the raw logits (subtract the max for numerical stability)
        float   max_prob = *std::max_element(outputs.begin<float>(), outputs.end<float>());
        cv::Mat softmax_prob;
        cv::exp(outputs - max_prob, softmax_prob);
        float sum = static_cast<float>(cv::sum(softmax_prob)[0]);
        softmax_prob /= sum;

        double    confidence;
        cv::Point class_id_point;
        cv::minMaxLoc(softmax_prob.reshape(1, 1), nullptr, &confidence, nullptr, &class_id_point);
        const int label_id = class_id_point.x;

        armor.confidence = static_cast<float>(confidence);
        // Guard against a label file shorter than the network's output vector:
        // indexing out of range here would be undefined behavior.
        if(label_id >= 0 && label_id < static_cast<int>(class_names_.size()))
        {
            armor.number = class_names_[label_id];
        }
        else
        {
            armor.number = "unknown";
        }
        armor.idx = label_id;

        std::stringstream result_ss;
        result_ss << armor.number << ": " << std::fixed << std::setprecision(1) << armor.confidence * 100.0 << "%";
        armor.classification_result = result_ss.str();
    }

    // Filter out implausible detections.
    armors.erase(
        std::remove_if(
            armors.begin(), armors.end(),
            [this](const Armor &armor)
            {
                // Drop results with confidence below the configured threshold.
                if(armor.confidence < threshold)
                {
                    return true;
                }

                // Drop classes the user asked to ignore.
                if(std::find(ignore_classes_.begin(), ignore_classes_.end(), armor.number) != ignore_classes_.end())
                {
                    return true;
                }

                // Drop number/armor-size mismatches:
                // a LARGE plate can never be "outpost", "2" (engineer) or "guard";
                // a SMALL plate can never be "1" (hero) or "base".
                bool mismatch_armor_type = false;
                if(armor.armor_type == ArmorType::LARGE)
                {
                    mismatch_armor_type
                        = armor.number == "outpost" || armor.number == "2" || armor.number == "guard";
                }
                else if(armor.armor_type == ArmorType::SMALL)
                {
                    mismatch_armor_type = armor.number == "1" || armor.number == "base";
                }
                return mismatch_armor_type;
            }),
        armors.end());
}

}  // namespace hnurm
